{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Custom Chains\n",
    "\n",
    "Sometimes no predefined chain fits your problem, you want more control over what happens, or you just \n",
    "want a better understanding of what chains actually do. In those cases you can write a custom chain."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dotenv import find_dotenv, load_dotenv\n",
    "\n",
    "# Locate the .env file and load its variables into the process environment\n",
    "load_dotenv(find_dotenv())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.callbacks.stdout import StdOutCallbackHandler\n",
    "from langchain.chat_models.openai import ChatOpenAI\n",
    "from langchain.prompts.prompt import PromptTemplate\n",
    "from langchain.chains import LLMChain\n",
    "\n",
    "\n",
    "# Build a single-step chain: the prompt template is filled in and sent to the chat model\n",
    "joke_prompt = PromptTemplate.from_template(\"tell us a joke about {topic}\")\n",
    "chain = LLMChain(prompt=joke_prompt, llm=ChatOpenAI())\n",
    "\n",
    "# Attach a stdout callback handler for this run only\n",
    "chain.run(\"chains\", callbacks=[StdOutCallbackHandler()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import annotations\n",
    "\n",
    "from typing import Any, Dict, List, Optional\n",
    "\n",
    "from pydantic import Extra\n",
    "\n",
    "from langchain.schema import BaseLanguageModel\n",
    "from langchain.callbacks.manager import (\n",
    "    AsyncCallbackManagerForChainRun,\n",
    "    CallbackManagerForChainRun,\n",
    ")\n",
    "from langchain.chains.base import Chain\n",
    "from langchain.prompts.base import BasePromptTemplate\n",
    "\n",
    "\n",
    "class WikipediaArticleChain(Chain):\n",
    "    \"\"\"Chain that asks a language model for a brief Wikipedia-style article.\n",
    "\n",
    "    The prompt's input variables define the inputs this chain expects, and the\n",
    "    text of the model's first generation is returned under ``output_key``.\n",
    "    Only synchronous execution is implemented.\n",
    "    \"\"\"\n",
    "\n",
    "    # Template that is formatted with the chain inputs before the model is called\n",
    "    prompt: BasePromptTemplate\n",
    "    # Language model that generates the article text\n",
    "    llm: BaseLanguageModel\n",
    "    # Key under which the generated text is returned\n",
    "    output_key: str = \"article\"\n",
    "\n",
    "    class Config:\n",
    "        \"\"\"Pydantic settings: reject unknown fields, permit arbitrary field types.\"\"\"\n",
    "\n",
    "        arbitrary_types_allowed = True\n",
    "        extra = Extra.forbid\n",
    "\n",
    "    @property\n",
    "    def input_keys(self) -> List[str]:\n",
    "        \"\"\"Names of the inputs the chain expects, taken from the prompt's variables.\"\"\"\n",
    "        return self.prompt.input_variables\n",
    "\n",
    "    @property\n",
    "    def output_keys(self) -> List[str]:\n",
    "        \"\"\"Single-element list holding the configured output key.\"\"\"\n",
    "        return [self.output_key]\n",
    "\n",
    "    def _call(\n",
    "        self,\n",
    "        inputs: Dict[str, Any],\n",
    "        run_manager: Optional[CallbackManagerForChainRun] = None,\n",
    "    ) -> Dict[str, str]:\n",
    "        \"\"\"Format the prompt with ``inputs``, query the model and return its text.\n",
    "\n",
    "        Args:\n",
    "            inputs: Values for the prompt's input variables.\n",
    "            run_manager: Optional callback manager for this run. When given, its\n",
    "                child callbacks are passed to the model call and a status\n",
    "                message is emitted through ``on_text``.\n",
    "\n",
    "        Returns:\n",
    "            A dict mapping ``output_key`` to the text of the first generation.\n",
    "        \"\"\"\n",
    "        formatted_prompt = self.prompt.format_prompt(**inputs)\n",
    "        child_callbacks = run_manager.get_child() if run_manager else None\n",
    "        llm_result = self.llm.generate_prompt(\n",
    "            [formatted_prompt], callbacks=child_callbacks\n",
    "        )\n",
    "\n",
    "        # Report completion to the callbacks only when a run manager was supplied\n",
    "        if run_manager:\n",
    "            run_manager.on_text(\"Generated Wikipedia article on given topic\")\n",
    "\n",
    "        # One prompt was sent, so take the first candidate of the first result\n",
    "        article_text = llm_result.generations[0][0].text\n",
    "        return {self.output_key: article_text}\n",
    "\n",
    "    async def _acall(\n",
    "        self,\n",
    "        inputs: Dict[str, Any],\n",
    "        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,\n",
    "    ) -> Dict[str, str]:\n",
    "        \"\"\"Async execution is not implemented for this chain.\n",
    "\n",
    "        Raises:\n",
    "            NotImplementedError: Always.\n",
    "        \"\"\"\n",
    "        raise NotImplementedError(\"Does not support async\")\n",
    "\n",
    "    @property\n",
    "    def _chain_type(self) -> str:\n",
    "        \"\"\"String identifier for this chain type.\"\"\"\n",
    "        return \"wikipedia_article_chain\"\n",
    "\n",
    "\n",
    "from langchain.callbacks.stdout import StdOutCallbackHandler\n",
    "from langchain.chat_models.openai import ChatOpenAI\n",
    "from langchain.prompts.prompt import PromptTemplate\n",
    "\n",
    "# Prompt with a single `topic` placeholder\n",
    "wikipedia_prompt = PromptTemplate.from_template(\n",
    "    \"Write a brief Wikipedia-style article on the topic {topic}\"\n",
    ")\n",
    "\n",
    "# Wire the prompt and a chat model into the custom chain defined above\n",
    "chain = WikipediaArticleChain(prompt=wikipedia_prompt, llm=ChatOpenAI())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the custom chain on a topic, with a stdout callback handler attached for this run\n",
    "stdout_handler = StdOutCallbackHandler()\n",
    "result = chain.run(\"quantum physics\", callbacks=[stdout_handler])\n",
    "\n",
    "# Show what kind of object `run` returned\n",
    "print(type(result))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call the chain directly with a dict of inputs instead of using `run`\n",
    "chain_inputs = {\"topic\": \"Quantum Physics\"}\n",
    "output_dict = chain(chain_inputs, callbacks=[StdOutCallbackHandler()])\n",
    "\n",
    "# Show what kind of object the direct call returned\n",
    "print(type(output_dict))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "output_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
